Deep Fake Detection¶

Image face classification with Fake and Real faces

Dataset from Kaggle link

Importing Libraries¶

Все необходимые для запуска библиотеки, а также константы (необходимо установить подходящие значения)

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

import torchvision
from torchvision import models as torch_models
from torchvision.transforms import v2 as transforms

import vit_pytorch

from torchinfo import summary as info_summary

import os
import time
import json
import random
import warnings
import numpy as np
import pandas as pd
from sklearn import metrics
import matplotlib.pyplot as plt


labels_map = {0: "Fake", 1: "Real"}
labels_map_reversed = dict(zip(labels_map.values(), labels_map.keys()))

# <- Enter your path
DATASET_PATH = "S://DataSet/DeepFake"
# DATASET_PATH = "D://BigDataSets/DeepFake"
# DATASET_PATH = "D://BigFiles/DataSets/DeepFake"
MODEL_DIR = "./models"

transform_base = transforms.Compose([
    transforms.ToImage(),  # transforms.ToTensor()
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

warnings.filterwarnings('ignore', category=FutureWarning)

# Debug mode
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Boost speed
USE_AMP = True
torch.backends.cudnn.benchmark = True  
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using PyTorch version: {torch.__version__}\nDevice: {device}")
Using PyTorch version: 2.5.1
Device: cuda

Utils¶

Полезные для мониторинга, вывода и контроля обучения функции

In [3]:
def progress_bar(n_iter, n_total, prefix='Progress: ', suffix='', length=55, fill='█', lost='-', percent=True):
    percent = f"{100 * (n_iter / float(n_total)) :.1f}% " if percent else ""
    filled_length = round(length * n_iter // n_total)
    bar = fill * filled_length + lost * (length - filled_length)
    print(f'\r{prefix}[{n_iter}/{n_total}] |{bar}| {percent}{suffix}', end=' ' * 10)
    if n_iter == n_total:
        print()


class AddGaussianNoise:
    def __init__(self, mean=0., std=1., p=1.):
        self.std = std
        self.mean = mean
        self.p = p

    def __call__(self, tensor):
        if np.random.rand() > self.p:
            return tensor

        return tensor + (torch.randn(tensor.size()) * self.std + self.mean) * torch.randn([1])

    def __repr__(self):
        return self.__class__.__name__ + f'({self.mean = }, {self.std = })'


class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0.):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.best_loss = 1e9
        self.counter = 0
        self.early_stop = False

    def step(self, val_loss):
        if (val_loss - self.best_loss) > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:  
                self.early_stop = True

        self.best_loss = min(val_loss, self.best_loss)
    
    def __bool__(self):
        return self.early_stop
    
class SaveOnlyBestModel:
    def __init__(self, model_ref: nn.Module, **kwargs):
        self.best_val_loss = 1e9
        self.model = model_ref
        self.kwargs = kwargs

    def step(self, val_loss):
        if (val_loss - self.best_val_loss) < 0.:
            save_model(self.model, **self.kwargs)
        self.best_val_loss = min(val_loss, self.best_val_loss)
        
class StairAccumulateGradient:
    def __init__(self, stairs_dict: dict):
        self.stairs_dict = stairs_dict
        self.current = 1
    
    def __call__(self, epoch: int):
        self.current = self.stairs_dict.get(epoch, self.current)
        return self.current
    
    
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """persistent_workers=True analog"""
    class RepeatSampler(object):
        def __init__(self, sampler):
            self.sampler = sampler
    
        def __iter__(self):
            while True:
                yield from iter(self.sampler)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._DataLoader__initialized = False
        self.batch_sampler = MultiEpochsDataLoader.RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


def save_model(model_: nn.Module, name="", only_weights=False):
    name = name if name else model_.name
    if only_weights:
        torch.save(model_.state_dict(), os.path.join(MODEL_DIR, f"{name}_weights.pth"))
    else:
        torch.save(model_, os.path.join(MODEL_DIR, f"{name}.pth"))

def load_model_weights(model_class, name="") -> nn.Module:
    name = name if name else model_class.base_name
    model_ = model_class()
    model_.load_state_dict(torch.load(os.path.join(MODEL_DIR, f"{name}_weights.pth")))
    model_.to(device).eval()
    return model_

def load_model(name: str) -> nn.Module:
    model_ = torch.load(os.path.join(MODEL_DIR, f"{name}.pth")).to(device).eval()
    return model_


def show_history(history_):
    plt.figure(figsize=(13, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history_["train_losses"], label='Training loss')
    plt.plot(history_["val_losses"], label='Validation loss')
    plt.legend(frameon=False)
    plt.title("Loss on training")
    
    plt.subplot(1, 2, 2)
    plt.plot(history_["train_accuracy"], label='Training accuracy')
    plt.plot(history_["val_accuracy"], label='Validation accuracy')
    plt.legend(frameon=False)
    plt.title("Accuracy on training")
    plt.show()

    
def show_from_loader(loader_or_batch: torch.utils.data.DataLoader, is_show_predict=False,
                     rows=1, cols=8, augmentation=None):
    images, labels = next(iter(loader_or_batch))
    images, labels = images[:rows*cols], labels[:rows*cols]
    if augmentation:
        images = augmentation(images)

    if is_show_predict:
        _, outputs = forward_step(images, labels)
    
    # Normalize to imshow
    images -= images.min()
    images /= images.max()
    images = np.clip(images, 0, 1)
    labels = labels.numpy()
    # outputs = outputs.round(2)
    
    plt.figure(figsize=(int(2.3*cols), int(2.7*rows)))
    for i in range(cols*rows):
        plt.subplot(rows, cols, i + 1)
        image = images[i]
        image = image.numpy().transpose((1, 2, 0))
        image = np.clip(image, 0., 1.)
        plt.imshow(image)
        plt.xticks([])
        plt.yticks([])
        if is_show_predict:
            plt.title(f"{labels_map[labels[i]]} ({outputs[i] :.3f})",
                      c=('g' if labels[i] == round(outputs[i]) else 'r'))
        else:
            plt.title(labels_map[labels[i]])
    
    plt.show()
    print(f"Min={images.min()} Max={images.max()} Mean={images.mean()} Std={images.std()}")

Dataset¶

Learn and load data

Make DataSet¶

Создание тренировочного и валидационного загрузчика

In [14]:
def get_dataset_of_category(category: str) -> torch.utils.data.Dataset:
    return torchvision.datasets.ImageFolder(
        root=os.path.join(DATASET_PATH, category),
        transform=transform_base)


dataset_train = get_dataset_of_category("Train")
dataset_val   = get_dataset_of_category("Validation")
dataset_test  = get_dataset_of_category("Test")

load_options = {'shuffle': True,}
if str(device) == 'cuda':
    load_options |= {'drop_last': True, 'pin_memory': True,
                     'prefetch_factor': 3, 'num_workers': 8, 'persistent_workers': True}
    

loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=128, **load_options)
loader_val   = torch.utils.data.DataLoader(dataset_val,  batch_size=256, **load_options)
loader_test  = torch.utils.data.DataLoader(dataset_test, batch_size=256, **load_options)
# loader = MultiEpochsDataLoader(dataset, ...)


print(dataset_train.class_to_idx,
      (device, load_options),
      (loader_train.pin_memory, loader_train.prefetch_factor, loader_train.num_workers),
      (len(dataset_train), len(dataset_val), len(dataset_test)),
      (loader_train.batch_size, loader_val.batch_size, loader_test.batch_size),
      sep='\n')
{'Fake': 0, 'Real': 1}
(device(type='cuda'), {'shuffle': True, 'drop_last': True, 'pin_memory': True, 'prefetch_factor': 3, 'num_workers': 8, 'persistent_workers': True})
(True, 3, 8)
(140002, 39428, 10921)
(128, 256, 256)
In [15]:
for imgs, labels in loader_train:
    print(f"Min={imgs.min()} Max={imgs.max()} Mean={imgs.mean()} Std={imgs.std()}")
    print(imgs.shape, labels[:10])
    print((imgs[0, 0], labels[0]))
    break
    
show_from_loader(loader_train, rows=1, cols=10)
Min=-2.1179039478302 Max=2.640000104904175 Mean=-0.17527608573436737 Std=1.1949788331985474
torch.Size([128, 3, 256, 256]) tensor([1, 1, 0, 1, 0, 0, 1, 0, 1, 0])
(tensor([[ 1.7009,  1.8550,  2.0263,  ...,  1.6324,  1.6324,  1.6495],
        [ 1.7694,  1.9064,  2.0605,  ...,  1.7009,  1.6667,  1.6495],
        [ 1.8893,  1.9920,  2.0948,  ...,  1.6838,  1.6153,  1.5639],
        ...,
        [-1.3302, -1.3130, -1.3130,  ..., -1.9809, -1.9980, -1.9980],
        [-1.3130, -1.3130, -1.2959,  ..., -1.9809, -1.9980, -1.9980],
        [-1.2959, -1.2959, -1.2788,  ..., -1.9809, -1.9980, -2.0152]]), tensor(1))
No description has been provided for this image
Min=0.0 Max=1.0 Mean=0.36924514174461365 Std=0.24825681746006012

Augmentation for data¶

loader -> transform_base -> (transform_train_vN -> train_model) -> pred -> train
loader -> transform_base ->             predict_model           -> pred
In [6]:
# transforms.Resize(size=(256, 256))
# transforms.RandomAdjustSharpness(sharpness_factor=1.5, p=0.35)
# transforms.ColorJitter(brightness=color_alpha, contrast=color_alpha, hue=0, saturation=color_alpha),  color_alpha = 0.22
# transforms.RandomPerspective(distortion_scale=0.2, p=0.2),

# transforms = v2.Compose([
#     v2.ToImage(),  # Convert to tensor, only needed if you had a PIL image
#     v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
#     ....................
#     v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
#     v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])


transform_train_v1 = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
])

transform_train_v2 = transforms.Compose([
    transforms.RandAugment(), transforms.RandomHorizontalFlip(p=0.5),])

transform_train_v3 = transforms.Compose([
    transforms.AugMix(), transforms.RandomHorizontalFlip(p=0.5),])

transform_train_v4 = transforms.Compose([
    transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.IMAGENET),
    transforms.RandomHorizontalFlip(p=0.5),])

transform_train_v5 = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.4),
    transforms.ColorJitter(brightness=0.07, contrast=0.07, hue=0.07, saturation=0.07),
    transforms.RandomRotation(degrees=7),
    transforms.RandomPerspective(distortion_scale=0.1, p=0.2),
])  # AddGaussianNoise(mean=0.0, std=0.05, p=0.07),

# 1) 100% Fast
# 2) 70% Mean
# 3) 30% Slooow...
# 4) 60% Mean
# 5) 15% Too very slooow...

Explore dataset¶

Наглядное сравнение разных аугментаций

In [ ]:
simple_batch = next(iter(loader_train))

show_from_loader(simple_batch, rows=1, cols=10)
show_from_loader(simple_batch, augmentation=transform_train_v1, rows=1, cols=10)
show_from_loader(simple_batch, augmentation=transform_train_v2, rows=1, cols=10)
show_from_loader(simple_batch, augmentation=transform_train_v3, rows=1, cols=10)
show_from_loader(simple_batch, augmentation=transform_train_v4, rows=1, cols=10)
show_from_loader(simple_batch, augmentation=transform_train_v5, rows=1, cols=10)

Train cycle¶

In [7]:
def binary_accuracy_count(outputs, labels):
    return (np.round(outputs) == labels.numpy()).sum()


def forward_step(images, labels, augmentation=None):
    if augmentation:
        images = augmentation(images)
    images, labels = images.to(device), labels.to(device).float()
    
    with torch.autocast(device_type=str(device), dtype=torch.float16, enabled=USE_AMP):
        outputs = model(images)
        loss = loss_function(outputs, labels)
    
    return loss, F.sigmoid(outputs.data).detach().cpu().numpy()
In [8]:
def run_loader(loader: torch.utils.data.DataLoader, epoch=0, type_run="<type>",
               is_train=False, augmentation=None, options={}):
    # Metrics
    sum_loss = accuracy_count = 0
    
    # Params
    accum_grad = options.get('gradient_accumulation_steps', StairAccumulateGradient({}))(epoch)
    grad_clip = options.get('grad_clip', None)
    
    # Run
    for i, (images, labels) in enumerate(loader):
        loss, outputs = forward_step(images, labels, augmentation)
        
        # Backward propagation
        if is_train:
            loss /= accum_grad
            stat_batch_size.scale(loss).backward()
            
            if i % accum_grad == 0:
                if grad_clip is not None:  # Gradient Clip
                    stat_batch_size.unscale_(optimizer)
                    torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip)
                
                # Update scaler params
                stat_batch_size.step(optimizer)
                stat_batch_size.update()
                
                # Prevent accumulation of gradients
                optimizer.zero_grad(set_to_none=True)
            
            loss *= accum_grad

        # Metrics - Loss & Accuracy
        sum_loss += loss.item()
        correct_count = binary_accuracy_count(outputs, labels)
        accuracy_count += correct_count

        progress_bar(i, len(loader), prefix=f"Epoch {epoch} {type_run}: ",
                     suffix=f"loss: {loss.item() :.3f}\taccuracy: {correct_count / len(labels):.3f}")
    
    return sum_loss / len(loader), accuracy_count / len(loader.dataset)


def train(epochs: int, augmentation=None, callbacks={}, options={}):
    train_losses, val_losses = [], []
    train_accuracy, val_accuracy = [], []
    optimizer.zero_grad(set_to_none=True)
    
    # Epochs
    for epoch in range(1, epochs+1):
        start_time = time.time()
        
        # Model Training
        model.train(True)
        train_loss, train_accuracy_epoch = run_loader(
            loader_train, epoch=epoch, type_run='train',
            is_train=True, augmentation=augmentation, options=options
        )
    
        # Model Validation
        model.train(False)
        with torch.inference_mode():  # torch.no_grad()
            val_loss, val_accuracy_epoch = run_loader(
                loader_val, epoch=epoch, type_run='val',
                is_train=False
            )

        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        train_accuracy.append(train_accuracy_epoch)
        val_accuracy.append(val_accuracy_epoch)
        
        progress_bar(len(loader_train), len(loader_train), length=10, percent=False,
                     prefix=f"Epoch {epoch} ({time.time() - start_time :.0f} sec): ",
                     suffix=f"loss: {train_losses[-1] :.4f}\taccuracy: {train_accuracy[-1] :.3f}\t"
                            f"val_loss: {val_losses[-1] :.4f}\tval_accuracy: {val_accuracy[-1] :.3f}\t" +
                            (f"Lr (e-3): {callbacks['lr_scheduler'].optimizer.param_groups[0]['lr'] * 1e3:.2f}"
                             if callbacks.get("lr_scheduler") else ""))
        
        # Callbacks
        if "early_stopping" in callbacks:
            callbacks["early_stopping"].step(val_losses[-1])
            if callbacks["early_stopping"]:
                print("INFO: Early stopping!")
                break
        if "lr_scheduler" in callbacks:
            callbacks["lr_scheduler"].step(val_losses[-1])
        if "save_best_model" in callbacks:
            callbacks["save_best_model"].step(val_losses[-1])


    return {"train_losses": train_losses,
            "val_losses": val_losses,
            "val_accuracy": val_accuracy,
            "train_accuracy": train_accuracy,}

Model architecture¶

In [9]:
class BaseModel:
    def __init__(self):
        self.name = "BaseModel"
        self.input_size = (256, 256)
        self.input_chanels = 3
        self.out_features = 1
In [10]:
class Model_SqueezeNeXt(nn.Module, BaseModel):
    # https://sh-tsang.medium.com/reading-squeezenext-hardware-aware-neural-network-design-image-classification-3fc8d1d3f76
    # https://github.com/osmr/imgclsmob/blob/c03fa67de3c9e454e9b6d35fe9cbb6b15c28fda7/pytorch/pytorchcv/models/squeezenext.py
    # https://arxiv.org/pdf/1803.10615v2.pdf
    class SqueezeNeXtUnit(nn.Module):
        def __init__(self,
                     in_channels,
                     out_channels,
                     stride):
            super().__init__()
            if stride == 2:
                reduction_den = 1
                self.resize_identity = True
            elif in_channels > out_channels:
                reduction_den = 4
                self.resize_identity = True
            else:
                reduction_den = 2
                self.resize_identity = False
                
            self.use_bias = False
            self.act = nn.ReLU(inplace=True)
            
            self.conv_x5 = nn.Sequential(
                # Conv 1 (1x1)
                nn.Conv2d(in_channels=in_channels, out_channels=(in_channels // reduction_den),
                          kernel_size=1, bias=self.use_bias, stride=stride),
                nn.BatchNorm2d(num_features=(in_channels // reduction_den)), self.act,
                # Conv 2 (1x1)
                nn.Conv2d(in_channels=(in_channels // reduction_den), out_channels=(in_channels // (2 * reduction_den)),
                          kernel_size=1, bias=self.use_bias),
                nn.BatchNorm2d(num_features=(in_channels // (2 * reduction_den))), self.act,
                # Conv 3 (1x3)
                nn.Conv2d(in_channels=(in_channels // (2 * reduction_den)), out_channels=(in_channels // reduction_den),
                          kernel_size=(1, 3), padding=(0, 1), bias=self.use_bias),
                nn.BatchNorm2d(num_features=(in_channels // reduction_den)), self.act,
                # Conv 4 (3x1)
                nn.Conv2d(in_channels=(in_channels // reduction_den), out_channels=(in_channels // reduction_den),
                          kernel_size=(3, 1), padding=(1, 0), bias=self.use_bias),
                nn.BatchNorm2d(num_features=(in_channels // reduction_den)), self.act,
                # Conv 5 (1x1)
                nn.Conv2d(in_channels=(in_channels // reduction_den), out_channels=out_channels,
                          kernel_size=1, bias=self.use_bias),
                nn.BatchNorm2d(num_features=out_channels), self.act,
            )

            if self.resize_identity:
                self.identity_conv = nn.Sequential(
                    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                              stride=stride, bias=self.use_bias),
                    nn.BatchNorm2d(num_features=out_channels), self.act
                )
    
        def forward(self, x):
            if self.resize_identity:
                identity = self.identity_conv(x)
            else:
                identity = x
            x = self.conv_x5(x)
            x = x + identity
            return self.act(x)

    def __init__(self, name="v1",      # Name iteration of model
                 width_scale=1.,       # Wide of model
                 layers=[2, 4, 14, 1]  # From original SqNxt-23
                 ):
        # super().__init__()
        nn.Module.__init__(self)
        BaseModel.__init__(self)
        
        self.name = f"{width_scale:.1f}-SqNxt-{sum(layers)+2}_{name}"
        
        # SqNxt-23v5
        # num_classes = 3
        init_block_channels = int(64 * width_scale)
        final_block_channels = int(128 * width_scale)
        channels_per_layers = [32, 64, 128, 256]
        channels = [[int(ci * width_scale)] * li for (ci, li) in zip(channels_per_layers, layers)]
        print(self.name, [init_block_channels], *channels, [final_block_channels])
        
        # Make arch
        self.features = nn.Sequential()
        self.features.add_module("first_block", nn.Sequential(
            nn.Conv2d(3, init_block_channels, kernel_size=5, stride=2),
            nn.BatchNorm2d(init_block_channels), nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        ))
        
        in_channels = init_block_channels
        for i, channels_per_stage in enumerate(channels):
            stage = nn.Sequential()
            for j, out_channels in enumerate(channels_per_stage):
                stride = 2 if (j == 0) and (i != 0) else 1
                stage.add_module(f"unit_{j + 1}", Model_SqueezeNeXt.SqueezeNeXtUnit(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    stride=stride))
                in_channels = out_channels
            self.features.add_module(f"stage_{i + 1}", stage)
        
        self.features.add_module("final_block", nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=final_block_channels, kernel_size=1, bias=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        ))

        self.output = nn.Sequential(
            nn.Linear(in_features=final_block_channels, out_features=self.out_features),
            # nn.Sigmoid(), nn.Flatten(start_dim=0),  # Add for binary classification by me (CE -> BCE loss)
            nn.Flatten(start_dim=0),  # Because Loss is BCE+Logits (for correct work AMP)
        )

    def forward(self, x):
        # x - must be normalized (mean=0, std=1)
        # x = transform_model_forward(x)
        x = self.features(x)
        x = self.output(x)
        return x
In [11]:
class Model_ResNet(nn.Module, BaseModel):
    def __init__(self, name="v1", is_pretrained=False, freeze_grad=False):
        # super().__init__()
        nn.Module.__init__(self)
        BaseModel.__init__(self)
        
        self.name = f"ResNet-34{'pre' if is_pretrained else ''}_{name}"
        
        # resnet18 resnet34 resnet50 resnet101
        self._pretrain_model = torch_models.resnet34(pretrained=is_pretrained)
        self._pretrain_model.requires_grad_(not freeze_grad)
        
        layers = list(self._pretrain_model.children())[:-1]
        layers.append(nn.Flatten())
        
        self.features = nn.Sequential(*layers)
        
        self.output = nn.Sequential(
            nn.Linear(
                in_features=self._pretrain_model.fc.in_features,
                out_features=self.out_features),
            nn.Flatten(start_dim=0)
        )

    def forward(self, x):  # x - must be normalized (mean=0, std=1)
        x = self.features(x)
        x = self.output(x)
        return x
In [12]:
class Model_ViT(nn.Module, BaseModel):
    def __init__(self, name="v1"):
        # super().__init__()
        nn.Module.__init__(self)
        BaseModel.__init__(self)
        
        # dim, m_dim = 1024, 2048
        # depth, heads = 6, 16
        dim, m_dim = 512, 1024
        depth, heads = 4, 12
        
        self.name = f"ViT-{depth}-{heads}-{dim}-{m_dim}_{name}"
        
        self.model = vit_pytorch.SimpleViT(
            image_size = 256,
            patch_size = 32,
            num_classes = 1,
            depth = depth,
            heads = heads,
            dim = dim,
            mlp_dim = m_dim,
            # dropout = 0.1,
            # emb_dropout = 0.1
        )
        # SimpleViT no dropout
        
        self.output = nn.Sequential(
            nn.Flatten(start_dim=0)
        )

    def forward(self, x):  # x - must be normalized (mean=0, std=1)
        x = self.model(x)
        x = self.output(x)
        return x

Training model¶

In [13]:
# Model
model = Model_SqueezeNeXt(
    name="v1",
    width_scale=1.2,      # 1.0
    layers=[6, 7, 5, 0]   # [2, 4, 14, 1]
)

# model = Model_ResNet(
#     name="v1",
#     is_pretrained=True,
#     freeze_grad=False
# )

# model = Model_ViT(name="nogit")
# model = load_model("1.0-SqNxt-17")

model.to(device)


# Summary model
with torch.autocast(device_type=str(device), dtype=torch.float16, enabled=USE_AMP):
    print(str(info_summary(model, input_size=(loader_train.batch_size, 3, 256, 256),
                           depth=3, device=device, mode="train",  # col_width=20 depth=5
                           col_names=["num_params", "mult_adds", "output_size"], row_settings=["var_names"])))


# Define our loss function
loss_function = nn.BCEWithLogitsLoss()  # nn.BCELoss()  # nn.CrossEntropyLoss()


# Define the optimizer
learning_rate = 1e-3
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)  # Adam RMSprop Adafactor


# AMP (Auto Mixed Precision)
stat_batch_size = torch.amp.GradScaler(device=device, enabled=USE_AMP)

# Callbacks & Schedulers
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=1, factor=0.7, min_lr=3e-5)
early_stopping = EarlyStopping(tolerance=11, min_delta=0.)
save_best_model = SaveOnlyBestModel(model)
callbacks = {
    "lr_scheduler": lr_scheduler,
    "early_stopping": early_stopping,
    "save_best_model": save_best_model,
}

# Options
options = {
    "gradient_accumulation_steps": StairAccumulateGradient({1:2, 2:3, 5:4, 10:5, 16:6}),
    # "gradient_accumulation_steps": StairAccumulateGradient({1:1, 2:2, 5:3, 10:4, 16:5}),
    "grad_clip": None,  # Not improve (0.1)  -  grad_clip_value
}
1.2-SqNxt-20_v1 [76] [38, 38, 38, 38, 38, 38] [76, 76, 76, 76, 76, 76, 76] [153, 153, 153, 153, 153] [] [153]
=============================================================================================================================
Layer (type (var_name))                            Param #                   Mult-Adds                 Output Shape
=============================================================================================================================
Model_SqueezeNeXt (Model_SqueezeNeXt)              --                        --                        [64]
├─Sequential (features)                            --                        --                        [64, 153]
│    └─Sequential (first_block)                    --                        --                        [64, 76, 63, 63]
│    │    └─Conv2d (0)                             5,776                     5,868,785,664             [64, 76, 126, 126]
│    │    └─BatchNorm2d (1)                        152                       9,728                     [64, 76, 126, 126]
│    │    └─ReLU (2)                               --                        --                        [64, 76, 126, 126]
│    │    └─MaxPool2d (3)                          --                        --                        [64, 76, 63, 63]
│    └─Sequential (stage_1)                        --                        --                        [64, 38, 63, 63]
│    │    └─SqueezeNeXtUnit (unit_1)               7,105                     1,732,661,312             [64, 38, 63, 63]
│    │    └─SqueezeNeXtUnit (unit_2)               3,419                     815,658,688               [64, 38, 63, 63]
│    │    └─SqueezeNeXtUnit (unit_3)               3,419                     815,658,688               [64, 38, 63, 63]
│    │    └─SqueezeNeXtUnit (unit_4)               3,419                     815,658,688               [64, 38, 63, 63]
│    │    └─SqueezeNeXtUnit (unit_5)               3,419                     815,658,688               [64, 38, 63, 63]
│    │    └─SqueezeNeXtUnit (unit_6)               3,419                     815,658,688               [64, 38, 63, 63]
│    └─Sequential (stage_2)                        --                        --                        [64, 76, 32, 32]
│    │    └─SqueezeNeXtUnit (unit_1)               15,010                    946,376,320               [64, 76, 32, 32]
│    │    └─SqueezeNeXtUnit (unit_2)               13,414                    851,732,608               [64, 76, 32, 32]
│    │    └─SqueezeNeXtUnit (unit_3)               13,414                    851,732,608               [64, 76, 32, 32]
│    │    └─SqueezeNeXtUnit (unit_4)               13,414                    851,732,608               [64, 76, 32, 32]
│    │    └─SqueezeNeXtUnit (unit_5)               13,414                    851,732,608               [64, 76, 32, 32]
│    │    └─SqueezeNeXtUnit (unit_6)               13,414                    851,732,608               [64, 76, 32, 32]
│    │    └─SqueezeNeXtUnit (unit_7)               13,414                    851,732,608               [64, 76, 32, 32]
│    └─Sequential (stage_3)                        --                        --                        [64, 153, 16, 16]
│    │    └─SqueezeNeXtUnit (unit_1)               59,056                    948,903,424               [64, 153, 16, 16]
│    │    └─SqueezeNeXtUnit (unit_2)               52,974                    854,249,856               [64, 153, 16, 16]
│    │    └─SqueezeNeXtUnit (unit_3)               52,974                    854,249,856               [64, 153, 16, 16]
│    │    └─SqueezeNeXtUnit (unit_4)               52,974                    854,249,856               [64, 153, 16, 16]
│    │    └─SqueezeNeXtUnit (unit_5)               52,974                    854,249,856               [64, 153, 16, 16]
│    └─Sequential (stage_4)                        --                        --                        [64, 153, 16, 16]
│    └─Sequential (final_block)                    --                        --                        [64, 153]
│    │    └─Conv2d (0)                             23,562                    386,039,808               [64, 153, 16, 16]
│    │    └─AdaptiveAvgPool2d (1)                  --                        --                        [64, 153, 1, 1]
│    │    └─Flatten (2)                            --                        --                        [64, 153]
├─Sequential (output)                              --                        --                        [64]
│    └─Linear (0)                                  154                       9,856                     [64, 1]
│    └─Flatten (1)                                 --                        --                        [64]
=============================================================================================================================
Total params: 474,593
Trainable params: 474,593
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 22.49
=============================================================================================================================
Input size (MB): 50.33
Forward/backward pass size (MB): 3074.59
Params size (MB): 1.68
Estimated Total Size (MB): 3126.60
=============================================================================================================================
In [ ]:
print(model.name)

history = train(epochs=100, augmentation=transform_train_v1, callbacks=callbacks, options=options)
show_history(history)
In [44]:
print(model.name)

history = train(epochs=100, augmentation=transform_train_v1, callbacks=callbacks, options=options)
show_history(history)
1.0-SqNxt-15_v1
Epoch 1 (258 sec): [546/546] |██████████| loss: 0.5445	accuracy: 0.706	val_loss: 0.4056	val_accuracy: 0.812	Lr (e-3): 1.00          
Epoch 2 (255 sec): [546/546] |██████████| loss: 0.2395	accuracy: 0.898	val_loss: 0.4983	val_accuracy: 0.806	Lr (e-3): 1.00          
Epoch 3 (193 sec): [546/546] |██████████| loss: 0.1635	accuracy: 0.932	val_loss: 0.2146	val_accuracy: 0.914	Lr (e-3): 1.00          
Epoch 4 (188 sec): [546/546] |██████████| loss: 0.1281	accuracy: 0.947	val_loss: 0.2200	val_accuracy: 0.911	Lr (e-3): 1.00          
Epoch 5 (188 sec): [546/546] |██████████| loss: 0.0970	accuracy: 0.960	val_loss: 0.1841	val_accuracy: 0.928	Lr (e-3): 1.00          
Epoch 6 (193 sec): [546/546] |██████████| loss: 0.0871	accuracy: 0.963	val_loss: 0.1607	val_accuracy: 0.938	Lr (e-3): 1.00          
Epoch 7 (235 sec): [546/546] |██████████| loss: 0.0816	accuracy: 0.966	val_loss: 0.2300	val_accuracy: 0.916	Lr (e-3): 1.00          
Epoch 8 (195 sec): [546/546] |██████████| loss: 0.0777	accuracy: 0.967	val_loss: 0.1728	val_accuracy: 0.937	Lr (e-3): 1.00          
Epoch 9 (188 sec): [546/546] |██████████| loss: 0.0660	accuracy: 0.972	val_loss: 0.1663	val_accuracy: 0.939	Lr (e-3): 0.70          
Epoch 10 (188 sec): [546/546] |██████████| loss: 0.0561	accuracy: 0.976	val_loss: 0.1615	val_accuracy: 0.943	Lr (e-3): 0.70          
Epoch 11 (187 sec): [546/546] |██████████| loss: 0.0480	accuracy: 0.979	val_loss: 0.1517	val_accuracy: 0.948	Lr (e-3): 0.49          
Epoch 12 (188 sec): [546/546] |██████████| loss: 0.0452	accuracy: 0.980	val_loss: 0.1423	val_accuracy: 0.954	Lr (e-3): 0.49          
Epoch 13 (187 sec): [546/546] |██████████| loss: 0.0438	accuracy: 0.981	val_loss: 0.1698	val_accuracy: 0.946	Lr (e-3): 0.49          
Epoch 14 (187 sec): [546/546] |██████████| loss: 0.0419	accuracy: 0.982	val_loss: 0.1805	val_accuracy: 0.943	Lr (e-3): 0.49          
Epoch 15 (187 sec): [546/546] |██████████| loss: 0.0370	accuracy: 0.984	val_loss: 0.1458	val_accuracy: 0.955	Lr (e-3): 0.34          
Epoch 16 (188 sec): [546/546] |██████████| loss: 0.0342	accuracy: 0.985	val_loss: 0.1535	val_accuracy: 0.954	Lr (e-3): 0.34          
INFO: Early stopping!
No description has been provided for this image
In [22]:
print(model.name)

history = train(epochs=100, augmentation=transform_train_v1, callbacks=callbacks, options=options)
show_history(history)
Epoch 1 (532 sec): [2187/2187] |██████████| loss: 0.4425	accuracy: 0.770	val_loss: 0.3558	val_accuracy: 0.854	Lr (e-3): 1.00          
Epoch 2 (497 sec): [2187/2187] |██████████| loss: 0.1791	accuracy: 0.929	val_loss: 0.2158	val_accuracy: 0.916	Lr (e-3): 1.00          
Epoch 3 (490 sec): [2187/2187] |██████████| loss: 0.1196	accuracy: 0.952	val_loss: 0.1929	val_accuracy: 0.922	Lr (e-3): 1.00          
Epoch 4 (495 sec): [2187/2187] |██████████| loss: 0.1037	accuracy: 0.959	val_loss: 0.1740	val_accuracy: 0.937	Lr (e-3): 1.00          
Epoch 5 (492 sec): [2187/2187] |██████████| loss: 0.0920	accuracy: 0.963	val_loss: 0.1574	val_accuracy: 0.937	Lr (e-3): 1.00          
Epoch 6 (491 sec): [2187/2187] |██████████| loss: 0.0838	accuracy: 0.966	val_loss: 0.1521	val_accuracy: 0.943	Lr (e-3): 1.00          
Epoch 7 (493 sec): [2187/2187] |██████████| loss: 0.0794	accuracy: 0.968	val_loss: 0.1307	val_accuracy: 0.947	Lr (e-3): 1.00          
Epoch 8 (483 sec): [2187/2187] |██████████| loss: 0.0722	accuracy: 0.971	val_loss: 0.1791	val_accuracy: 0.928	Lr (e-3): 1.00          
Epoch 9 (494 sec): [2187/2187] |██████████| loss: 0.0615	accuracy: 0.975	val_loss: 0.1185	val_accuracy: 0.955	Lr (e-3): 1.00          
Epoch 10 (495 sec): [2187/2187] |██████████| loss: 0.0587	accuracy: 0.976	val_loss: 0.1181	val_accuracy: 0.958	Lr (e-3): 1.00          
Epoch 11 (517 sec): [2187/2187] |██████████| loss: 0.0553	accuracy: 0.978	val_loss: 0.1282	val_accuracy: 0.953	Lr (e-3): 1.00          
Epoch 12 (508 sec): [2187/2187] |██████████| loss: 0.0540	accuracy: 0.978	val_loss: 0.1287	val_accuracy: 0.955	Lr (e-3): 1.00          
Epoch 13 (499 sec): [2187/2187] |██████████| loss: 0.0444	accuracy: 0.982	val_loss: 0.0977	val_accuracy: 0.965	Lr (e-3): 0.68          
Epoch 14 (493 sec): [2187/2187] |██████████| loss: 0.0422	accuracy: 0.982	val_loss: 0.1061	val_accuracy: 0.961	Lr (e-3): 0.68          
Epoch 15 (487 sec): [2187/2187] |██████████| loss: 0.0368	accuracy: 0.985	val_loss: 0.1065	val_accuracy: 0.962	Lr (e-3): 0.68          
Epoch 16 (487 sec): [2187/2187] |██████████| loss: 0.0303	accuracy: 0.988	val_loss: 0.0931	val_accuracy: 0.968	Lr (e-3): 0.46          
Epoch 17 (488 sec): [2187/2187] |██████████| loss: 0.0276	accuracy: 0.988	val_loss: 0.0975	val_accuracy: 0.970	Lr (e-3): 0.46          
Epoch 18 (489 sec): [2187/2187] |██████████| loss: 0.0270	accuracy: 0.988	val_loss: 0.1019	val_accuracy: 0.967	Lr (e-3): 0.46          
Epoch 19 (492 sec): [2187/2187] |██████████| loss: 0.0225	accuracy: 0.990	val_loss: 0.0996	val_accuracy: 0.969	Lr (e-3): 0.31          
Epoch 20 (494 sec): [2187/2187] |██████████| loss: 0.0200	accuracy: 0.992	val_loss: 0.1179	val_accuracy: 0.970	Lr (e-3): 0.31          
INFO: Early stopping!
No description has been provided for this image

Test model¶

In [16]:
def test_model(loader: torch.utils.data.DataLoader, augmentation=None, show=True):
    """:returns (y_true, y_pred), loss, accuracy, class_report, auc_score"""
    start_time = time.time()
    
    b_size = loader.batch_size
    y_true = np.zeros((len(loader) * b_size),)
    y_pred = np.zeros((len(loader) * b_size),)
    
    with torch.inference_mode():
        # loss, accuracy = run_loader(loader, type_run='test', is_train=False, augmentation=augmentation)
        
        # Metrics
        sum_loss = accuracy_count = 0
        
        # Run
        for i, (images, labels) in enumerate(loader):
            loss, outputs = forward_step(images, labels, augmentation)
            
            y_true[b_size * i:b_size * (i + 1)] = labels.numpy()
            y_pred[b_size * i:b_size * (i + 1)] = outputs

            # Metrics - Loss & Accuracy
            sum_loss += loss.item()
            correct_count = binary_accuracy_count(outputs, labels)
            accuracy_count += correct_count
    
            progress_bar(i, len(loader), prefix=f"Testing ({time.time() - start_time :.1f} sec): \t",
                         suffix=f" \tloss: {loss.item() :.3f} \taccuracy: {correct_count / len(labels):.3f}")
    
    loss, accuracy = sum_loss / len(loader), accuracy_count / len(loader.dataset)
    progress_bar(len(loader), len(loader), length=10, percent=False,
                 prefix=f"Testing ({time.time() - start_time :.1f} sec): \t",
                 suffix=f" \tloss: {loss :.5f} \taccuracy: {accuracy :.4f}" + " " * 40)
    
    class_report = show_metrics(y_true, y_pred, show=show)
    auc_score = show_roc_auc(y_true, y_pred, show=show)
    
    return (y_true, y_pred), loss, accuracy, class_report, auc_score

def show_metrics(y_true, y_pred, show=True):
    y_pred_one_hot = np.round(y_pred)
    labels = list(labels_map_reversed.keys())
    
    conf_matrix = metrics.confusion_matrix(y_true, y_pred_one_hot)
    class_report = metrics.classification_report(y_true, y_pred_one_hot, target_names=labels, output_dict=True)
    if show:
        print(metrics.classification_report(y_true, y_pred_one_hot, digits=4, target_names=labels))
        metrics.ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=labels).plot()
        plt.title(f'Confusion Matrix {model.name}')
        plt.show()
    return class_report

def show_roc_auc(y_true, y_pred, show=True):
    fpr, tpr, _ = metrics.roc_curve(y_true, y_pred)
    auc_score = metrics.roc_auc_score(y_true, y_pred)
    if show:
        plt.title(f'ROC AUC {model.name}')
        plt.plot(fpr, tpr, label=f"Auc = {auc_score*100 :.2f}")
        plt.plot([0, 1], [0, 1],'r--')
        plt.xlim([-0.03, 1.03])
        plt.ylim([-0.03, 1.03])
        plt.legend(loc=4)
        plt.grid(True)
        plt.ylabel('True Positive Rate')
        plt.xlabel('False Positive Rate')
        plt.show()
    return auc_score
In [17]:
model = load_model("ResNet-18")

# _ = test_model(loader_val)
_ = test_model(loader_test)
Testing (42.7 sec): 	[42/42] |██████████|  	loss: 0.29526 	accuracy: 0.9034                                                          
              precision    recall  f1-score   support

        Fake     0.8835    0.9637    0.9218      5421
        Real     0.9593    0.8708    0.9129      5331

    accuracy                         0.9176     10752
   macro avg     0.9214    0.9172    0.9174     10752
weighted avg     0.9211    0.9176    0.9174     10752

No description has been provided for this image
No description has been provided for this image
In [16]:
with open(os.path.join(MODEL_DIR, "stat.json")) as file:
    statistic = json.load(file)
    # print(statistic)

table_stats = {"model": [], "loss": [], "accuracy": [], "f1": [], "auc": [],
               "val_loss": [], "val_accuracy": [], "val_f1": [], "val_auc": [],
               "params": [], "params_mb": [], "mult_ads_gflops": [], "pass_size_mb": [], "total_pass_size_mb": [],
               "test_batch_size": []}

for filename in os.listdir(MODEL_DIR):
    print(filename)
    if filename.endswith(".pth") :#and filename not in statistic: # and filename not in statistic:
        model = load_model(filename[:-4])
        
        _, test_loss, test_accuracy, test_class_report, test_auc = test_model(loader_test, show=False)
        _, val_loss, val_accuracy, val_class_report, val_auc = test_model(loader_val, show=False)
        
        stat_batch_size = 128
        with torch.autocast(device_type=str(device), dtype=torch.float16, enabled=USE_AMP):
            info = info_summary(model, input_size=(stat_batch_size, 3, 256, 256), depth=3, device=device, mode="train",
                                col_names=["num_params", "mult_adds", "output_size"], row_settings=["var_names"])
        
        info_file = {
            "Loss":                round(test_loss, 5),
            "Accuracy":            round(test_accuracy * 100, 3),
            "F1-score":            round(test_class_report['weighted avg']['f1-score'], 5),
            "AUC-score":           round(test_auc, 5),
            "Val loss":            round(val_loss, 5),
            "Val accuracy":        round(val_accuracy * 100, 3),
            "Val F1-score":        round(val_class_report['weighted avg']['f1-score'], 5),
            "Val AUC-score":       round(val_auc, 5),
            f"Structure (bs={stat_batch_size})": 
            {
                "Params": info.total_params,
                "Mult-adds GFlops": round(info.total_mult_adds / 1024 ** 3, 1),
                "Params (MB)": round(info.total_param_bytes / 1024 ** 2, 2),
                "Forward/backward pass size (MB)": round(info.total_output_bytes / 1024 ** 2, 1),
                "Estimated Total Size (MB)":
                    round((info.total_output_bytes + info.total_input + info.total_param_bytes) / 1024 ** 2, 1),
            }
        }
        table_stats["model"].append(filename)
        table_stats["loss"].append(info_file["Loss"])
        table_stats["accuracy"].append(info_file["Accuracy"])
        table_stats["f1"].append(info_file["F1-score"])
        table_stats["auc"].append(info_file["AUC-score"])
        table_stats["val_loss"].append(info_file["Val loss"])
        table_stats["val_accuracy"].append(info_file["Val accuracy"])
        table_stats["val_f1"].append(info_file["Val F1-score"])
        table_stats["val_auc"].append(info_file["Val AUC-score"])
        struct_model = info_file[f"Structure (bs={stat_batch_size})"]
        table_stats["params"].append(struct_model["Params"])
        table_stats["params_mb"].append(struct_model["Params (MB)"])
        table_stats["mult_ads_gflops"].append(struct_model["Mult-adds GFlops"])
        table_stats["pass_size_mb"].append(struct_model["Forward/backward pass size (MB)"])
        table_stats["total_pass_size_mb"].append(struct_model["Estimated Total Size (MB)"])
        table_stats["test_batch_size"].append(stat_batch_size)

        statistic[filename] = info_file

        with open(os.path.join(MODEL_DIR, "stat.json"), 'w') as file:
            json.dump(dict(sorted(statistic.items())), file, indent=4)

df = pd.DataFrame.from_dict(table_stats)
df.to_csv(os.path.join(MODEL_DIR, "stat.csv"), index=False, encoding='utf-8')
0.8-SqNxt-17.pth
Testing (8.6 sec): 	[42/42] |██████████|  	loss: 0.35750 	accuracy: 0.8518                                                          
Testing (64.0 sec): 	[154/154] |██████████|  	loss: 0.15500 	accuracy: 0.9427                                                          
1.0-SqNxt-15.pth
Testing (7.7 sec): 	[42/42] |██████████|  	loss: 0.43043 	accuracy: 0.8493                                                          
Testing (25.5 sec): 	[154/154] |██████████|  	loss: 0.14233 	accuracy: 0.9535                                                          
1.0-SqNxt-17.pth
Testing (7.8 sec): 	[42/42] |██████████|  	loss: 0.52064 	accuracy: 0.8550                                                          
Testing (25.2 sec): 	[154/154] |██████████|  	loss: 0.09656 	accuracy: 0.9661                                                          
1.2-SqNxt-20.pth
Testing (8.1 sec): 	[42/42] |██████████|  	loss: 0.46883 	accuracy: 0.8639                                                          
Testing (27.4 sec): 	[154/154] |██████████|  	loss: 0.12763 	accuracy: 0.9611                                                          
1.4-SqNxt-18.pth
Testing (8.7 sec): 	[42/42] |██████████|  	loss: 0.51017 	accuracy: 0.8453                                                          
Testing (26.5 sec): 	[154/154] |██████████|  	loss: 0.11148 	accuracy: 0.9605                                                          
2.6-SqNxt-20.pth
Testing (14.1 sec): 	[42/42] |██████████|  	loss: 0.58686 	accuracy: 0.8778                                                          
Testing (42.8 sec): 	[154/154] |██████████|  	loss: 0.09308 	accuracy: 0.9682                                                          
ResNet-18.pth
Testing (8.5 sec): 	[42/42] |██████████|  	loss: 0.29705 	accuracy: 0.9031                                                          
Testing (26.9 sec): 	[154/154] |██████████|  	loss: 0.08534 	accuracy: 0.9732                                                          
ResNet-18pre.pth
Testing (8.5 sec): 	[42/42] |██████████|  	loss: 0.29383 	accuracy: 0.8894                                                          
Testing (27.1 sec): 	[154/154] |██████████|  	loss: 0.08218 	accuracy: 0.9692                                                          
ResNet-34pre_nogit.pth
Testing (10.5 sec): 	[42/42] |██████████|  	loss: 0.46238 	accuracy: 0.8720                                                          
Testing (35.3 sec): 	[154/154] |██████████|  	loss: 0.05314 	accuracy: 0.9821                                                          
ResNet-50pre_nogit.pth
Testing (18.0 sec): 	[42/42] |██████████|  	loss: 0.51507 	accuracy: 0.8713                                                          
Testing (55.1 sec): 	[154/154] |██████████|  	loss: 0.05762 	accuracy: 0.9784                                                          
ResNet-50_nogit.pth
Testing (16.9 sec): 	[42/42] |██████████|  	loss: 0.29958 	accuracy: 0.8990                                                          
Testing (53.4 sec): 	[154/154] |██████████|  	loss: 0.07499 	accuracy: 0.9752                                                          
stat.csv
stat.json
ViT-4-12-512-1024_nogit.pth
Testing (7.9 sec): 	[42/42] |██████████|  	loss: 0.55986 	accuracy: 0.6953                                                          
Testing (26.5 sec): 	[154/154] |██████████|  	loss: 0.52846 	accuracy: 0.7322                                                          
In [17]:
df = pd.read_csv(os.path.join(MODEL_DIR, "stat.csv"), index_col=0)
df[["accuracy", "val_accuracy"]] /= 100
df
Out[17]:
loss accuracy f1 auc val_loss val_accuracy val_f1 val_auc params params_mb mult_ads_gflops pass_size_mb total_pass_size_mb test_batch_size
model
0.8-SqNxt-17.pth 0.35750 0.85185 0.86494 0.94790 0.15500 0.94271 0.94279 0.98661 348022 0.95 18.6 3351.6 3448.5 128
1.0-SqNxt-15.pth 0.43043 0.84928 0.86202 0.95032 0.14233 0.95354 0.95362 0.98989 276433 0.91 25.7 3824.6 3921.5 128
1.0-SqNxt-17.pth 0.52064 0.85505 0.86738 0.95521 0.09656 0.96606 0.96616 0.99490 288481 0.95 27.9 4341.7 4438.7 128
1.2-SqNxt-20.pth 0.46883 0.86393 0.87689 0.95977 0.12763 0.96107 0.96116 0.99230 474593 1.60 41.9 5864.3 5961.9 128
1.4-SqNxt-18.pth 0.51017 0.84525 0.85708 0.95227 0.11148 0.96051 0.96061 0.99358 1536162 4.72 50.7 5192.1 5292.8 128
2.6-SqNxt-20.pth 0.58686 0.87776 0.89114 0.94616 0.09308 0.96825 0.96834 0.99555 5546447 17.23 170.7 10806.3 10919.5 128
ResNet-18.pth 0.29705 0.90312 0.91712 0.98142 0.08534 0.97324 0.97334 0.99639 11690025 42.64 282.4 3168.0 3306.6 128
ResNet-18pre.pth 0.29383 0.88939 0.90325 0.96315 0.08218 0.96924 0.96931 0.99742 11690025 42.64 282.4 3168.0 3306.6 128
ResNet-34pre_nogit.pth 0.46238 0.87199 0.88460 0.97266 0.05314 0.98214 0.98224 0.99826 21798185 81.20 570.4 4768.0 4945.2 128
ResNet-50pre_nogit.pth 0.51507 0.87135 0.88422 0.95932 0.05762 0.97844 0.97854 0.99793 25559081 89.68 636.4 14176.0 14361.7 128
ResNet-50_nogit.pth 0.29958 0.89900 0.91287 0.97842 0.07499 0.97522 0.97532 0.99676 25559081 89.68 636.4 14176.0 14361.7 128
ViT-4-12-512-1024_nogit.pth 0.55986 0.69527 0.70619 0.78240 0.52846 0.73220 0.73206 0.81186 12082177 46.09 1.4 1072.0 1214.1 128
In [18]:
def show_bar_plot(title: str):
    plt.title(title)
    plt.xlabel("Models")
    plt.ylabel("Value")
    plt.grid(axis="y")
    plt.show()


labels_rot = 60

# Loss
df.sort_values(by=["loss"], ascending=False, inplace=True)  # df.sort_index(inplace=True)
df[["loss", "val_loss"]].plot.bar(rot=labels_rot, figsize=(10, 6))
show_bar_plot("Loss (less is better)")

# Test Metric
df.sort_values(by=["accuracy"], ascending=True, inplace=True)  # df.sort_index(inplace=True)
df[["accuracy", "f1", "auc"]].plot.bar(rot=labels_rot, figsize=(15, 6))
show_bar_plot("Metrics (more is better)")

# Val Metric
df[["val_accuracy", "val_f1", "val_auc"]].plot.bar(rot=labels_rot, figsize=(15, 6))
show_bar_plot("Metrics validation (more is better)")

# Param
df.sort_index(inplace=True)
df[["params"]].plot.bar(rot=labels_rot, figsize=(9, 4))
show_bar_plot("Params of models")

# VRAM
df[["pass_size_mb", "total_pass_size_mb"]].plot.bar(rot=labels_rot, figsize=(9, 4))
show_bar_plot("VRAM usings")

# Computations speed
df[["mult_ads_gflops"]].plot.bar(rot=labels_rot, figsize=(9, 4))
show_bar_plot("Mult-adds GFlops (speed - less is better)")
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [16]:
model = load_model("ResNet-18")
show_from_loader(loader_val, rows=5, cols=12, is_show_predict=True)
No description has been provided for this image
Min=0.0 Max=1.0 Mean=0.3857888877391815 Std=0.24290801584720612

Production¶

Использование модели в рабочем режиме

Модель берет изображения из TEST_IMGS_DIR а предсказание модели помещает в файл prediction.csv

In [17]:
TEST_IMGS_DIR = "example/imgs/1"
TEST_DIR = "example"

model = load_model("ResNet-18")

# Loader
dataset_end_test = torchvision.datasets.ImageFolder(
    root=os.path.split(TEST_IMGS_DIR)[0],
    transform=transforms.Compose([transform_base, transforms.Resize((256, 256)),]))
loader_end_test = torch.utils.data.DataLoader(dataset_end_test, batch_size=128, shuffle=False, drop_last=False, pin_memory=True,)

print(sorted(os.listdir(TEST_IMGS_DIR)))
print(len(loader_end_test.dataset))
['fake_8379.jpg', 'fake_8380.jpg', 'fake_8381.jpg', 'fake_8382.jpg', 'fake_8383.jpg', 'fake_8384.jpg', 'fake_8385.jpg', 'fake_8386.jpg', 'fake_8387.jpg', 'fake_8388.jpg', 'fake_8389.jpg', 'fake_8390.jpg', 'fake_8391.jpg', 'fake_8392.jpg', 'fake_8393.jpg', 'fake_8394.jpg', 'fake_8395.jpg', 'fake_8396.jpg', 'fake_8397.jpg', 'fake_8398.jpg', 'fake_8399.jpg', 'fake_8400.jpg', 'fake_8401.jpg', 'fake_8402.jpg', 'fake_8403.jpg', 'fake_8404.jpg', 'fake_8405.jpg', 'fake_8406.jpg', 'fake_8407.jpg', 'fake_8408.jpg', 'fake_8409.jpg', 'fake_8410.jpg', 'fake_8411.jpg', 'fake_8412.jpg', 'fake_8413.jpg', 'fake_8414.jpg', 'fake_8415.jpg', 'fake_8416.jpg', 'fake_8417.jpg', 'fake_8418.jpg', 'fake_8419.jpg', 'fake_8420.jpg', 'fake_8421.jpg', 'fake_8422.jpg', 'fake_8423.jpg', 'fake_8424.jpg', 'fake_8425.jpg', 'fake_8426.jpg', 'fake_8427.jpg', 'fake_8428.jpg', 'fake_8429.jpg', 'fake_8430.jpg', 'fake_8431.jpg', 'fake_8432.jpg', 'fake_8433.jpg', 'fake_8434.jpg', 'fake_8435.jpg', 'fake_8436.jpg', 'fake_8437.jpg', 'fake_8438.jpg', 'fake_8439.jpg', 'fake_8440.jpg', 'fake_8441.jpg', 'fake_8442.jpg', 'fake_8443.jpg', 'fake_8444.jpg', 'fake_8445.jpg', 'fake_8446.jpg', 'fake_8447.jpg', 'fake_8448.jpg', 'fake_8449.jpg', 'fake_8450.jpg', 'fake_8451.jpg', 'fake_8452.jpg', 'fake_8453.jpg', 'fake_8454.jpg', 'fake_8455.jpg', 'fake_8456.jpg', 'fake_8457.jpg', 'fake_8458.jpg', 'fake_8459.jpg', 'fake_8460.jpg', 'fake_8461.jpg', 'fake_8462.jpg', 'fake_8463.jpg', 'fake_8464.jpg', 'fake_8465.jpg', 'fake_8466.jpg', 'fake_8467.jpg', 'fake_8468.jpg', 'fake_8469.jpg', 'fake_8470.jpg', 'fake_8471.jpg', 'fake_8472.jpg', 'fake_8473.jpg', 'fake_8474.jpg', 'fake_8475.jpg', 'fake_8476.jpg', 'fake_8477.jpg', 'fake_8478.jpg', 'fake_8479.jpg', 'fake_8480.jpg', 'fake_8481.jpg', 'fake_8482.jpg', 'fake_8483.jpg', 'fake_8484.jpg', 'fake_8485.jpg', 'fake_8486.jpg', 'fake_8487.jpg', 'fake_8488.jpg', 'fake_8489.jpg', 'fake_8490.jpg', 'fake_8491.jpg', 'fake_8492.jpg', 'fake_8493.jpg', 'fake_8494.jpg', 'fake_8495.jpg', 'fake_8496.jpg', 'fake_8497.jpg', 'fake_8498.jpg', 'fake_8499.jpg', 'fake_8500.jpg', 'fake_8501.jpg', 'fake_8502.jpg', 'fake_8503.jpg', 'fake_8504.jpg', 'fake_8505.jpg', 'fake_8506.jpg', 'fake_8507.jpg', 'fake_8508.jpg', 'fake_8509.jpg', 'fake_8510.jpg', 'fake_8511.jpg', 'fake_8512.jpg', 'fake_8513.jpg', 'fake_8514.jpg', 'fake_8515.jpg', 'fake_8516.jpg', 'fake_8517.jpg', 'fake_8518.jpg', 'fake_8519.jpg', 'fake_8520.jpg', 'fake_8521.jpg', 'fake_8522.jpg', 'fake_8523.jpg', 'fake_8524.jpg', 'fake_8525.jpg', 'fake_8526.jpg', 'fake_8527.jpg', 'fake_8528.jpg', 'fake_8529.jpg', 'fake_8530.jpg', 'fake_8531.jpg', 'fake_8532.jpg', 'fake_8533.jpg', 'fake_8534.jpg', 'fake_8535.jpg', 'fake_8536.jpg', 'fake_8537.jpg', 'fake_8538.jpg', 'fake_8539.jpg', 'fake_8540.jpg', 'fake_8541.jpg', 'fake_8542.jpg', 'fake_8543.jpg', 'fake_8544.jpg', 'fake_8545.jpg', 'fake_8546.jpg', 'fake_8547.jpg', 'fake_8548.jpg', 'fake_8549.jpg', 'fake_8550.jpg', 'fake_8551.jpg', 'fake_8552.jpg', 'fake_8553.jpg', 'fake_8554.jpg', 'fake_8555.jpg', 'fake_8556.jpg', 'fake_8557.jpg', 'fake_8558.jpg', 'real_369.jpg', 'real_370.jpg', 'real_371.jpg', 'real_372.jpg', 'real_373.jpg', 'real_374.jpg', 'real_375.jpg', 'real_376.jpg', 'real_377.jpg', 'real_378.jpg', 'real_379.jpg', 'real_380.jpg', 'real_381.jpg', 'real_382.jpg', 'real_383.jpg', 'real_384.jpg', 'real_385.jpg', 'real_386.jpg', 'real_387.jpg', 'real_388.jpg', 'real_389.jpg', 'real_390.jpg', 'real_391.jpg', 'real_392.jpg', 'real_393.jpg', 'real_394.jpg', 'real_395.jpg', 'real_405.jpg', 'real_406.jpg', 'real_407.jpg', 'real_408.jpg', 'real_409.jpg', 'real_410.jpg', 'real_411.jpg', 'real_412.jpg', 'real_413.jpg', 'real_414.jpg', 'real_415.jpg', 'real_416.jpg', 'real_417.jpg', 'real_418.jpg', 'real_419.jpg', 'real_420.jpg', 'real_421.jpg', 'real_422.jpg', 'real_423.jpg', 'real_424.jpg', 'real_425.jpg', 'real_426.jpg', 'real_427.jpg', 'real_428.jpg', 'real_429.jpg', 'real_430.jpg', 'real_431.jpg', 'real_432.jpg', 'real_433.jpg', 'real_434.jpg', 'real_435.jpg', 'real_436.jpg', 'real_437.jpg', 'real_438.jpg', 'real_439.jpg', 'real_440.jpg', 'real_441.jpg', 'real_442.jpg', 'real_443.jpg', 'real_444.jpg', 'real_445.jpg', 'real_446.jpg', 'real_447.jpg', 'real_448.jpg', 'real_449.jpg', 'real_450.jpg', 'real_451.jpg', 'real_452.jpg', 'real_453.jpg', 'real_454.jpg', 'real_455.jpg', 'real_456.jpg', 'real_457.jpg', 'real_458.jpg', 'real_459.jpg', 'real_460.jpg', 'real_461.jpg', 'real_462.jpg', 'real_463.jpg', 'real_464.jpg', 'real_465.jpg', 'real_466.jpg', 'real_467.jpg', 'real_468.jpg', 'real_469.jpg', 'real_470.jpg', 'real_471.jpg', 'real_472.jpg', 'real_473.jpg', 'real_474.jpg', 'real_475.jpg', 'real_476.jpg', 'real_477.jpg', 'real_478.jpg', 'real_479.jpg', 'real_480.jpg', 'real_481.jpg', 'real_482.jpg', 'real_483.jpg', 'real_484.jpg', 'real_485.jpg', 'real_486.jpg', 'real_487.jpg', 'real_488.jpg', 'real_489.jpg', 'real_490.jpg', 'real_491.jpg', 'real_492.jpg', 'real_493.jpg', 'real_494.jpg', 'real_495.jpg', 'real_496.jpg', 'real_497.jpg', 'real_498.jpg', 'real_499.jpg', 'real_500.jpg', 'real_501.jpg', 'real_502.jpg', 'real_503.jpg', 'real_504.jpg', 'real_505.jpg', 'real_506.jpg', 'real_507.jpg', 'real_508.jpg', 'real_509.jpg', 'real_510.jpg', 'real_511.jpg', 'real_512.jpg', 'real_513.jpg', 'real_514.jpg', 'real_515.jpg', 'real_516.jpg', 'real_517.jpg', 'real_518.jpg', 'real_519.jpg', 'real_520.jpg', 'real_521.jpg', 'real_522.jpg', 'real_523.jpg', 'real_524.jpg', 'real_525.jpg', 'real_526.jpg', 'real_527.jpg', 'real_528.jpg', 'real_529.jpg', 'real_530.jpg', 'real_531.jpg', 'real_532.jpg', 'real_533.jpg', 'real_534.jpg', 'real_535.jpg', 'real_536.jpg', 'real_537.jpg', 'real_538.jpg', 'real_539.jpg', 'real_540.jpg', 'real_541.jpg', 'real_542.jpg', 'real_543.jpg', 'real_544.jpg', 'real_545.jpg', 'real_546.jpg', 'real_547.jpg', 'real_548.jpg', 'real_549.jpg', 'real_550.jpg', 'real_551.jpg', 'real_552.jpg', 'real_553.jpg', 'real_554.jpg', 'real_555.jpg', 'real_556.jpg', 'real_557.jpg']
360
In [24]:
# # Rename
# fnames_history = {}
# filenames = os.listdir(TEST_IMGS_DIR)
# for i, name in zip(random.sample(range(len(filenames)), len(filenames)), sorted(filenames)):
#     os.rename(f"{TEST_IMGS_DIR}/{name}", f"{TEST_IMGS_DIR}/img_{i}.jpg")
#     fnames_history[f"img_{i}.jpg"] = name
# 
# for i in range(len(fnames_history)):
#     print(i, fnames_history[f"img_{i}.jpg"])
In [18]:
show_from_loader(loader_end_test, rows=2, cols=10)
No description has been provided for this image
Min=0.0 Max=1.0 Mean=0.4136466383934021 Std=0.2551622688770294
In [19]:
y_pred = np.zeros((len(loader_end_test.dataset)),)
start_idx = 0

with torch.inference_mode():
    for i, (images, labels) in enumerate(loader_end_test):
        _, outputs = forward_step(images, labels)
        y_pred[start_idx:start_idx+len(outputs)] = outputs
        start_idx += len(outputs)

print(y_pred.shape, y_pred.dtype)

df_res = pd.DataFrame.from_dict({
    "filename": sorted(os.listdir('example/imgs/1')),
    "pred": y_pred.round().astype(int),
    "value": y_pred.round(5)
})

df_res.to_csv(os.path.join(TEST_DIR, 'prediction.csv'), index=False, encoding='utf-8')
(360,) float64
In [28]:
# # Rename back
# for name, old_name in fnames_history.items():
#     os.rename(f"{TEST_IMGS_DIR}/{name}", f"{TEST_IMGS_DIR}/{old_name}")
In [20]:
df_res
Out[20]:
filename pred value
0 fake_8379.jpg 0 0.00000
1 fake_8380.jpg 0 0.00000
2 fake_8381.jpg 0 0.00008
3 fake_8382.jpg 0 0.00000
4 fake_8383.jpg 0 0.02104
... ... ... ...
355 real_553.jpg 1 1.00000
356 real_554.jpg 1 0.99854
357 real_555.jpg 1 1.00000
358 real_556.jpg 1 1.00000
359 real_557.jpg 1 1.00000

360 rows × 3 columns

In [18]:
2 + 2
Out[18]:
4